import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100

from backbone import ResNet_size
from data_utils import *

# Passed as ResNet_size(output_size=...) by both models defined below.
output_dim = 2
# Default class count (not referenced by the code in this module).
num_classes = 10
cuda = torch.cuda.is_available()
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if cuda else torch.device('cpu')

class Extension_CrossEntropy(torch.nn.Module):
    """Per-class "extended" cross-entropy computed on softmax probabilities.

    For logits ``x`` of shape ``(batch, C)`` and integer labels ``y`` of shape
    ``(batch,)`` every class contributes
    ``log(1 + a - |softmax(x) - onehot(y)| ** p) - log(1 + a)``; these terms
    are summed over classes, averaged over the batch, negated and divided by
    ``p``.  With the defaults ``a=0``, ``p=1`` this is
    ``-mean(log p_y + sum_{j != y} log(1 - p_j))``.

    Args:
        dim: Unused; kept so existing constructor calls keep working.
        a: Offset added inside the logarithm (default ``0``).
        p: Exponent applied to the absolute error (default ``1``).
    """

    def __init__(self, dim = 2, a = 0, p = 1):
        super(Extension_CrossEntropy, self).__init__()
        self.a = a
        self.p = p

    def forward(self, x, y):
        """Return the scalar loss for logits ``x`` and integer class labels ``y``.

        Raises:
            ValueError: if ``x`` contains NaN.
        """
        # torch.max propagates NaN, so a single reduction detects any NaN entry.
        if torch.isnan(torch.max(x)):
            raise ValueError("Extension_CrossEntropy received NaN logits")
        target = torch.nn.functional.one_hot(y, num_classes=x.shape[1])
        prob = torch.nn.functional.softmax(x, 1)
        arg = 1 + self.a - torch.abs(prob - target + 1e-20).pow(self.p)
        # A saturated softmax drives `arg` to exactly 0, and log(0) = -inf would
        # make the loss inf and the gradients NaN. Clamping to the smallest
        # positive normal float keeps both finite; only values below that
        # threshold are affected.
        arg = arg.clamp(min=torch.finfo(arg.dtype).tiny)
        loss = torch.log(arg) - torch.log(x.new_tensor(1 + self.a))
        return -torch.mean(torch.sum(loss, dim=1), dim=0) / self.p

class OntoEncoder(nn.Module):
    """ResNet encoder followed by one linear layer that produces class scores.

    Args:
        size: Depth passed to ``ResNet_size``.  When ``size > 49`` the pooled
            feature width is ``512 * 4`` instead of ``512`` (presumably the
            bottleneck ResNets -- confirm against ``backbone``).
        n_class: Number of outputs of the linear head.
        loss: ``'crossentropy'`` (uses ``Extension_CrossEntropy``) or ``'mse'``
            (mean pairwise L2 distance divided by ``n_class``).
    """

    def __init__(self, size = 18, n_class = None, loss = 'crossentropy'):
        super(OntoEncoder,self).__init__()
        self.encoder = ResNet_size(size=size, in_channels=3, output_size=output_dim, encoder = 'Yes')
        self.n_class = n_class
        self.loss = loss
        # Feature-width multiplier; the attribute keeps its historical spelling
        # in case other code reads it.
        self.expension = 1 + (size > 49)*3
        self.decoder = nn.Sequential(
            nn.Linear(512*self.expension, n_class),
        )

    def forward(self, x, mode = 'train'):
        """Return logits for ``mode='train'`` or softmax probabilities for ``mode='test'``.

        Raises:
            ValueError: for any other ``mode`` (previously such calls silently
                returned ``None``).
        """
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        semantic = self.encoder(x)
        pooled = nn.functional.adaptive_avg_pool2d(semantic, (1, 1))
        out = self.decoder(pooled.view(-1, 512*self.expension))
        if mode == 'test':
            return nn.functional.softmax(out, 1)
        return out

    def training_step(self, img, label = None, embedding_gt = None):
        """Return the scalar training loss for one batch.

        Args:
            img: Input batch; the callers in this file move it to ``device`` first.
            label: Per-sample targets.  For ``'crossentropy'`` they are reshaped
                to ``(batch,)``, moved to ``device`` and cast to ``long``.
            embedding_gt: Unsupported; must be ``None``.

        Raises:
            NotImplementedError: if ``embedding_gt`` is given.
            ValueError: if ``self.loss`` is neither ``'crossentropy'`` nor ``'mse'``.
        """
        if embedding_gt is not None:
            raise NotImplementedError("training_step does not support embedding_gt")
        if self.loss not in ('crossentropy', 'mse'):
            raise ValueError(f"unsupported loss type {self.loss!r}")
        out = self.forward(img, 'train')
        batch_size = label.shape[0]
        if self.loss == 'crossentropy':
            label = label.reshape(batch_size).to(device)
            loss_func = Extension_CrossEntropy()
            return loss_func(out.reshape(batch_size, self.n_class), label.long())
        # 'mse': mean L2 distance between outputs and targets, scaled by n_class.
        loss_func = nn.PairwiseDistance(p = 2)
        return torch.mean(loss_func(out, label.to(out.device)) / self.n_class)

    def evaluate_step(self, val_loader):
        """Return ``{'val_loss': float}``: the mean per-batch loss over ``val_loader``.

        Switches the model to eval mode and leaves it there; ``my_train`` calls
        ``model.train()`` at the start of every epoch.
        """
        with torch.no_grad():
            self.eval()
            val_loss = []
            for img, label in val_loader:
                img = img.float().to(device)
                loss = self.training_step(img, label)
                val_loss.append(loss.item())

        epoch_loss = torch.tensor(val_loss).mean()  # mean of the per-batch losses
        return {'val_loss': epoch_loss.item()}

    def extract_embedding(self, dataset, num_batches = 30):
        """Return ``(outputs, labels)`` numpy arrays for every sample of ``dataset``.

        Outputs are ``forward(..., 'test')`` probabilities computed in eval mode
        without gradients; the model's previous train/eval mode is restored.

        Args:
            dataset: Dataset yielding ``(image, label)`` pairs and exposing a
                ``targets`` attribute.
            num_batches: Roughly how many batches to split the dataset into;
                raise it if a batch does not fit in memory.
        """
        size_batch = max(1, len(dataset) // num_batches)
        data_loader = DeviceDataLoader(
            DataLoader(dataset=dataset, batch_size=size_batch, shuffle=False), device)

        was_training = self.training
        # eval(): in train() mode any BatchNorm layers (presumably present in the
        # ResNet encoder) would update their running statistics with every
        # extraction batch and make outputs depend on batch composition.
        self.eval()
        try:
            with torch.no_grad():
                chunks = [self.forward(data, 'test') for data, _ in data_loader]
        finally:
            self.train(was_training)

        if chunks:
            embedding = torch.cat(chunks, 0).cpu().numpy()
        else:
            embedding = np.empty((0, self.n_class), dtype=np.float32)
        label = torch.tensor(dataset.targets).cpu().numpy()

        del data_loader, chunks
        torch.cuda.empty_cache()  # release cached GPU memory held by the batches
        return embedding, label

    def get_loss_type(self):
        """Return the configured loss name (used to tag saved plots)."""
        return self.loss


class Pipeline(nn.Module):
    """ResNet encoder followed by a linear projection to ``embedded_dim``.

    Args:
        size: Depth passed to ``ResNet_size``.  When ``size > 49`` the pooled
            feature width is ``512 * 4`` instead of ``512`` (presumably the
            bottleneck ResNets -- confirm against ``backbone``).
        embedded_dim: Output width of the linear projection.  Training treats
            this output as class logits, so it must equal ``n_class``.
        n_class: Number of classes.
        loss: Only ``'crossentropy'`` is supported by ``training_step``.
    """

    def __init__(self, size = 18, embedded_dim = 3, n_class = None, loss = 'crossentropy'):
        super(Pipeline,self).__init__()
        self.encoder = ResNet_size(size=size, in_channels=3, output_size=output_dim, encoder = 'Yes')
        self.n_class = n_class
        self.loss = loss
        # Feature-width multiplier; the attribute keeps its historical spelling
        # in case other code reads it.
        self.expension = 1 + (size > 49)*3
        self.embedding = nn.Linear(512*self.expension, embedded_dim)

    def forward(self, x, mode = 'train'):
        """Return the linear embedding (``'train'``) or softmaxed pooled features (``'test'``).

        Raises:
            ValueError: for any other ``mode`` (previously such calls silently
                returned ``None``).
        """
        if mode not in ('train', 'test'):
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
        semantic = self.encoder(x)
        features = nn.functional.adaptive_avg_pool2d(semantic, (1, 1)).view(-1, 512*self.expension)
        if mode == 'test':
            # NOTE(review): this is the softmax of the pooled encoder features,
            # not of self.embedding(...) as OntoEncoder does -- confirm intended.
            return nn.functional.softmax(features, 1)
        return self.embedding(features)

    def training_step(self, img, label = None, embedding_gt = None):
        """Return the scalar cross-entropy loss for one batch.

        Args:
            img: Input batch; the callers in this file move it to ``device`` first.
            label: Class labels, reshaped to ``(batch,)`` and moved to ``device``.
            embedding_gt: Unsupported; must be ``None``.

        Raises:
            NotImplementedError: if ``embedding_gt`` is given.
            ValueError: if ``self.loss`` is not ``'crossentropy'``.
        """
        if embedding_gt is not None:
            raise NotImplementedError("training_step does not support embedding_gt")
        if self.loss != 'crossentropy':
            raise ValueError(f"unsupported loss type {self.loss!r}")
        out = self.forward(img, 'train')
        batch_size = label.shape[0]
        label = label.reshape(batch_size).to(device)
        loss_func = Extension_CrossEntropy()
        # The embedding is used directly as logits, so this reshape only works
        # when embedded_dim == n_class.
        return loss_func(out.reshape(batch_size, self.n_class), label.long())

    def evaluate_step(self, val_loader):
        """Return ``{'val_loss': float}``: the mean per-batch loss over ``val_loader``.

        Switches the model to eval mode and leaves it there; ``my_train`` calls
        ``model.train()`` at the start of every epoch.
        """
        with torch.no_grad():
            self.eval()
            val_loss = []
            for img, label in val_loader:
                img = img.float().to(device)
                loss = self.training_step(img, label)
                val_loss.append(loss.item())

        epoch_loss = torch.tensor(val_loss).mean()  # mean of the per-batch losses
        return {'val_loss': epoch_loss.item()}

    def extract_embedding(self, dataset, num_batches = 30):
        """Return ``(outputs, labels)`` numpy arrays for every sample of ``dataset``.

        Outputs are ``forward(..., 'test')`` values computed in eval mode
        without gradients; the model's previous train/eval mode is restored.

        Args:
            dataset: Dataset yielding ``(image, label)`` pairs and exposing a
                ``targets`` attribute.
            num_batches: Roughly how many batches to split the dataset into;
                raise it if a batch does not fit in memory.
        """
        size_batch = max(1, len(dataset) // num_batches)
        data_loader = DeviceDataLoader(
            DataLoader(dataset=dataset, batch_size=size_batch, shuffle=False), device)

        was_training = self.training
        # eval(): in train() mode any BatchNorm layers (presumably present in the
        # ResNet encoder) would update their running statistics with every
        # extraction batch and make outputs depend on batch composition.
        self.eval()
        try:
            with torch.no_grad():
                chunks = [self.forward(data, 'test') for data, _ in data_loader]
        finally:
            self.train(was_training)

        if chunks:
            embedding = torch.cat(chunks, 0).cpu().numpy()
        else:
            embedding = np.empty((0, 512 * self.expension), dtype=np.float32)
        label = torch.tensor(dataset.targets).cpu().numpy()

        del data_loader, chunks
        torch.cuda.empty_cache()  # release cached GPU memory held by the batches
        return embedding, label

    def get_loss_type(self):
        """Return the configured loss name (used to tag saved plots)."""
        return self.loss


class DeviceDataLoader:
    """Iterable wrapper that moves each batch of an underlying loader to ``device``.

    With ``yield_labels`` true every batch is unpacked as ``(inputs, labels)``
    and only ``inputs`` is moved; otherwise the whole batch is moved.
    """

    def __init__(self, data, device, yield_labels=True):
        self.data, self.device, self.yield_labels = data, device, yield_labels

    def __iter__(self):
        """Yield the wrapped loader's batches after moving them to ``self.device``."""
        if not self.yield_labels:
            yield from (to_device(batch, self.device) for batch in self.data)
            return
        for inputs, labels in self.data:
            yield to_device(inputs, self.device), labels

    def __len__(self):
        """Return the number of batches in the wrapped loader."""
        return len(self.data)
    
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to ``device``.

    Lists and tuples are returned as lists; each tensor goes through
    ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

def plot_train(epochs, train, val, type):
    """Plot per-epoch train/validation losses and save them to ``sample/<type>_train.png``.

    A fresh figure is used and closed afterwards, so repeated calls do not
    draw on top of each other; the ``sample`` directory is created if missing.

    Args:
        epochs: Number of epochs; the x-axis is ``0 .. epochs - 1``.
        train: Per-epoch training losses (length ``epochs``).
        val: Per-epoch validation losses (length ``epochs``).
        type: Tag used in the output file name (e.g. the loss type).  The
            name shadows the builtin but is kept for keyword callers.
    """
    fig = plt.figure()
    try:
        x = np.arange(epochs)
        plt.xlabel("epochs")
        plt.ylabel("loss")
        plt.plot(x, train, color = 'r', linestyle = '--', label = 'train')
        plt.plot(x, val, color = 'b', linestyle = '--', label = 'validation')
        # The curves are labelled, so actually show the legend.
        plt.legend()
        os.makedirs('sample', exist_ok=True)
        plt.savefig(os.path.join('sample', type + '_train.png'))
    finally:
        # Free the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group, or ``None`` if it has none."""
    return next((group['lr'] for group in optimizer.param_groups), None)


def _build_train_dataset(dataset, download):
    """Return the training split of ``dataset`` wrapped with its preprocessing.

    Args:
        dataset: ``'CIFAR10'``, ``'CIFAR100'`` or ``'CUB'``.
        download: Forwarded to the torchvision CIFAR constructors.

    Raises:
        ValueError: for any other dataset name.
    """
    if dataset == 'CIFAR10':
        # Normalisation constants (presumably CIFAR-10 per-channel mean/std).
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)) ])
        return CIFAR10(root='./datasets/', train=True, transform = preprocess, download=download)
    if dataset == 'CIFAR100':
        # Normalisation constants (presumably CIFAR-100 per-channel mean/std).
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5070751592371323, 0.48654887331495095, 0.4409178433670343), (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)) ])
        return CIFAR100(root='./datasets/', train=True, transform = preprocess, download=download)
    if dataset == 'CUB':
        # Random crop/flip augmentation plus the commonly used ImageNet mean/std.
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomCrop(224, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
        ])
        return CUB(path='./datasets/CUB_200_2011', train=True, transform=preprocess)
    raise ValueError(f"unknown dataset {dataset!r}; expected 'CIFAR10', 'CIFAR100' or 'CUB'")


def my_train(model, batch_size = 32, epochs = 40, max_lr= 0.1, weight_decay=0.0, grad_clip=None, opt_func=torch.optim.SGD, dataset = 'CIFAR10', download=True):
    """Train ``model`` on 90% of a dataset's training split, validating on the rest.

    Each epoch runs one pass of ``model.training_step`` + optimizer updates,
    steps a ``ReduceLROnPlateau`` scheduler on the mean training loss, then
    calls ``model.evaluate_step`` on the held-out 10%.  The loss curves are
    saved via ``plot_train``.

    Args:
        model: Module exposing ``training_step``, ``evaluate_step`` and
            ``get_loss_type`` (e.g. ``OntoEncoder``), already on ``device``.
        batch_size: Batch size of both loaders.
        epochs: Number of epochs.
        max_lr: Initial learning rate passed to ``opt_func``.
        weight_decay: Weight decay passed to ``opt_func``.
        grad_clip: If truthy, additionally clip every gradient element to
            ``[-grad_clip, grad_clip]``.
        opt_func: Optimizer class, called as ``opt_func(params, lr, weight_decay=...)``.
        dataset: ``'CIFAR10'``, ``'CIFAR100'`` or ``'CUB'``.
        download: Forwarded to torchvision for the CIFAR datasets.  Defaults to
            ``True`` because the old code passed the truthy string ``'False'``.

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean losses.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    train_dataset = _build_train_dataset(dataset, download)

    # Hold out 10% of the training split for validation.  Note the validation
    # subset shares the training transform (including CUB's random augmentation).
    val_size = int(len(train_dataset) * 0.10)
    train_size = len(train_dataset) - val_size
    train_ds, val_ds = random_split(train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    # Shuffling buys nothing during evaluation; a fixed order keeps it reproducible.
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    torch.cuda.empty_cache()

    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # The LR is reduced when the *training* loss plateaus (see sched.step below).
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.20, patience=3, verbose=True)
    train_mean_loss = []
    val_mean_loss = []
    for epoch in range(epochs):
        # Training mode for BatchNorm/Dropout; evaluate_step switches to eval.
        model.train()
        train_losses = []
        for img, label in train_loader:
            img = img.to(device)
            optimizer.zero_grad()
            loss = model.training_step(img, label)
            train_losses.append(loss.item())
            loss.backward()
            # Always cap the global gradient L2 norm at 10 ...
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=10, norm_type=2)
            # ... and optionally also clip each gradient element by value.
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()

        mean_loss = sum(train_losses) / len(train_losses)
        # LR used during this epoch, read before the scheduler may lower it.
        lr = get_lr(optimizer)
        sched.step(mean_loss)

        result = model.evaluate_step(val_loader)
        result['train_loss'] = mean_loss
        result['lrs'] = [lr]
        print(f"Epoch [{epoch + 1}/{epochs}], last_lr: {lr:.5f}, train_loss: {mean_loss:.4f}, val_loss: {result['val_loss']:.4f}")
        train_mean_loss.append(mean_loss)
        val_mean_loss.append(result['val_loss'])
    plot_train(epochs, train_mean_loss, val_mean_loss, model.get_loss_type())
    return train_mean_loss, val_mean_loss
    
    

if __name__ =='__main__':
    # Train an OntoEncoder with the settings below and save its state_dict.
    print('start')
    epochs = 50
    max_lr = 0.1
    batch_size = 32
    grad_clip = 0.1
    weight_decay = 1e-4
    opt_func = torch.optim.SGD
    loss = 'crossentropy'  # or 'mse'
    dataset = 'CUB'  # or 'CIFAR100' together with n_class = 100
    n_class = 200  # must match the number of classes in `dataset`
    size = 18
    checkpoint_dir = "checkpoint"
    # Create the output directory before training so a long run is not lost
    # to a FileNotFoundError in torch.save at the very end.
    os.makedirs(checkpoint_dir, exist_ok=True)
    model = to_device(OntoEncoder(size = size, n_class = n_class, loss=loss),device)
    my_train(model, batch_size, epochs, max_lr, weight_decay=weight_decay,
          grad_clip=grad_clip, opt_func=opt_func,dataset=dataset)
    print("\nSAVING MODEL\n")
    file_name = os.path.join(checkpoint_dir, f"{loss}_resnet{size}_{dataset}.pth")
    torch.save(model.state_dict(), file_name)
    



